# Augmentations & Regularizations for Knowledge Tracing

## Requirements
* pytorch 1.5.0
* wandb 0.8.36
* tqdm 4.46.0

## Example run
* DKT with replacement & insertion on the ASSISTmentsChall dataset. Both the hidden dimension and the embedding dimension of the model are set to 256.

```
python main.py --name=assistChall_dkt_rep0.3w100_corins0.5w1up_augkt_pos_qid_b64_cv1 --augmentations rep ins \
  --rep_cons_loss=1 --rep_kt_loss=1 --rep_pred=0 --rep_prob=0.3 --rep_weight=100.0 --ins_cons_loss=1 --ins_kt_loss=1 \
  --ins_prob=0.5 --ins_weight=1.0 --ins_response=1 --ins_loss_dir=up --data_root=/shared/ASSISTmentsChall \
  --dataset_name=ASSISTmentsChall --model_type=DKT --train_batch=64 --test_batch=512 --use_wandb=0 \
  --enc_feature_names interaction_idx --enc_feature_dims 256 --eval_steps=100 --num_steps=5000 --device=cuda \
  --num_workers=4 --gpu=0 --seq_size=200 --d_model_count=256 --head_count=8 \
  --dropout_rate=0.0 --split_num=1
```
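The example run applies replacement (`rep_prob=0.3`) and insertion (`ins_prob=0.5`, `ins_response=1`). As a rough illustration of what the three augmentations do to an interaction sequence, here is a minimal plain-Python sketch. It is not the repository's implementation: the function names `replace`, `insert`, and `delete`, the `(item_id, is_correct)` tuple representation, keeping the response unchanged on replacement, and inserting new interactions before existing ones are all assumptions.

```python
import random

def replace(seq, rep_prob, num_items, rng):
    """Replacement sketch: resample each item id with probability rep_prob.

    Assumption: the response (is_correct) of a replaced interaction is kept.
    Returns the augmented sequence and the indices that were replaced.
    """
    out, replaced = [], []
    for t, (item, correct) in enumerate(seq):
        if rng.random() < rep_prob:
            out.append((rng.randrange(num_items), correct))  # random new item
            replaced.append(t)
        else:
            out.append((item, correct))
    return out, replaced

def insert(seq, ins_prob, num_items, ins_response, rng):
    """Insertion sketch: before each interaction, insert a random item with
    probability ins_prob. ins_response is 1 (correct), 0 (incorrect) or "rand"."""
    out = []
    for item, correct in seq:
        if rng.random() < ins_prob:
            resp = rng.randint(0, 1) if ins_response == "rand" else int(ins_response)
            out.append((rng.randrange(num_items), resp))
        out.append((item, correct))  # the original interaction is always kept
    return out

def delete(seq, del_prob, del_response, rng):
    """Deletion sketch: drop each eligible interaction with probability del_prob.
    del_response is 1 or 0 (only correct / incorrect interactions are eligible)
    or "all" (every interaction is eligible)."""
    out = []
    for item, correct in seq:
        eligible = del_response == "all" or correct == int(del_response)
        if eligible and rng.random() < del_prob:
            continue  # interaction deleted
        out.append((item, correct))
    return out

# Toy sequence of (item_id, is_correct) pairs, augmented with the example run's settings.
rng = random.Random(0)
seq = [(3, 1), (7, 0), (2, 1), (9, 1), (4, 0)]
rep_seq, rep_idx = replace(seq, rep_prob=0.3, num_items=100, rng=rng)
ins_seq = insert(seq, ins_prob=0.5, num_items=100, ins_response=1, rng=rng)
```

In this sketch each interaction is augmented independently, so the number of replaced or inserted interactions varies between calls; `rep_prob=0.3` replaces about 30% of interactions on average.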

## Important arguments and hyperparameters
* ``model_type``: model architecture, DKT / qDKT / DKVMN / SAINT / (SAKT)
* ``encoder_feature_names, decoder_feature_names``: input features to use. `item_idx`, `is_correct`, `interaction_idx`, and `position` are available.
DKT and DKVMN use only encoder features, while SAINT uses both encoder (`item_idx`, `position`) and decoder (`position`, `is_correct`) features.
* ``encoder_feature_dims, decoder_feature_dims``: embedding dimension for each feature.
* ``augmentations``: augmentation methods to apply, any of `rep` (replacement), `ins` (insertion), and `del` (deletion).
* ``rep_prob, ins_prob, del_prob``: probability for each augmentation. For example, ``rep_prob=0.5`` means that about half of the interactions will be replaced.
* ``rep_weight, ins_weight, del_weight``: weight (lambda_reg-aug) of the regularization loss for each augmentation.
* ``rep_cons_loss, ins_cons_loss, del_cons_loss``: flags indicating whether to apply the consistency (replacement) or monotonicity (insertion, deletion) regularization.
* ``rep_pred, rep_only``: flags indicating whether the outputs at the replaced interactions are included in the consistency loss.
Default: `rep_pred='0', rep_only='0'` (both flags False)
* ``rep_kt_loss, ins_kt_loss, del_kt_loss``: flags indicating whether to include the BCE loss on the augmented sequences (lambda_aug).
* ``ins_response, del_response``: response of the inserted interactions, or of the interactions eligible for deletion: ``1`` (correct), ``0`` (incorrect), ``rand`` (random, insertion only), ``all`` (any response, deletion only).
* ``ins_loss_dir, del_loss_dir``: direction of the constraint. ``up`` means that the predicted probability should increase, and ``down`` means the opposite.
* ``lap_weight``: weight for Laplacian regularization (when `model_type` is qDKT)
* ``dataset_name``: dataset name, ASSISTmentsChall / ASSISTments2015 / STATICS2011 / EdNet-KT1 / (ASSISTments2009)
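The flags and weights above suggest a total loss of the form: original BCE + sum over augmentations of (`*_kt_loss` × BCE on the augmented sequence + `*_cons_loss` × `*_weight` × regularization loss). The sketch below assumes that form, a squared-error consistency loss, and a hinge-style monotonicity loss; the exact losses in `main.py` may differ, and `total_loss`, `consistency_loss`, `monotonicity_loss`, and the dictionary keys are hypothetical names. The prediction values are toy numbers, combined with the example run's weights (`rep_weight=100.0`, `ins_weight=1.0`, `ins_loss_dir=up`).

```python
import math

def bce(preds, labels):
    """Mean binary cross-entropy of predicted probabilities against 0/1 labels."""
    eps = 1e-7
    total = 0.0
    for p, y in zip(preds, labels):
        p = min(max(p, eps), 1.0 - eps)  # clip to keep log() finite
        total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total / len(preds)

def consistency_loss(orig_preds, aug_preds):
    """Assumed form: mean squared difference over aligned (non-replaced) positions."""
    return sum((o - a) ** 2 for o, a in zip(orig_preds, aug_preds)) / len(orig_preds)

def monotonicity_loss(orig_preds, aug_preds, loss_dir):
    """Assumed hinge form: penalize predictions that move against loss_dir.

    "up":   augmented predictions should not fall below the original ones.
    "down": augmented predictions should not rise above the original ones.
    """
    if loss_dir == "up":
        gaps = [max(0.0, o - a) for o, a in zip(orig_preds, aug_preds)]
    elif loss_dir == "down":
        gaps = [max(0.0, a - o) for o, a in zip(orig_preds, aug_preds)]
    else:
        raise ValueError(f"unknown loss_dir: {loss_dir!r}")
    return sum(gaps) / len(gaps)

def total_loss(ori_bce, aug_terms):
    """Combine the original BCE with per-augmentation terms.

    Each value in aug_terms is a dict with hypothetical keys:
      kt_loss   -> the *_kt_loss flag (0 or 1), i.e. lambda_aug
      aug_bce   -> BCE on the augmented sequence
      cons_loss -> the *_cons_loss flag (0 or 1)
      weight    -> the *_weight value, i.e. lambda_reg-aug
      reg       -> consistency or monotonicity loss
    """
    loss = ori_bce
    for terms in aug_terms.values():
        loss += terms["kt_loss"] * terms["aug_bce"]
        loss += terms["cons_loss"] * terms["weight"] * terms["reg"]
    return loss

# Toy predictions at the original positions of each sequence.
labels = [1, 0, 1]
orig = [0.7, 0.4, 0.6]
rep_aligned = [0.65, 0.45, 0.6]  # after replacement (non-replaced positions)
ins_aligned = [0.75, 0.35, 0.7]  # after inserting correct responses
loss = total_loss(
    bce(orig, labels),
    {
        "rep": {"kt_loss": 1, "aug_bce": bce(rep_aligned, labels), "cons_loss": 1,
                "weight": 100.0, "reg": consistency_loss(orig, rep_aligned)},
        "ins": {"kt_loss": 1, "aug_bce": bce(ins_aligned, labels), "cons_loss": 1,
                "weight": 1.0, "reg": monotonicity_loss(orig, ins_aligned, "up")},
    },
)
```

Under this form, setting a `*_cons_loss` flag to 0 drops that augmentation's regularization term regardless of its weight, and setting a `*_kt_loss` flag to 0 drops its augmented BCE term.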
